{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## FastBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Use the %pip magic rather than !pip so the package is installed into the\n",
    "# environment of the kernel that is running this notebook.\n",
    "%pip install fast-bert"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from transformers import BertTokenizer\n",
    "from fast_bert.data_cls import BertDataBunch\n",
    "from fast_bert.learner_cls import BertLearner\n",
    "from fast_bert.metrics import accuracy"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Before executing the cells below, create the data and label directories with `mkdir twitterdata labels`.\n",
    "\n",
    "Then set the paths:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Directory that holds the train/valid/test CSV files\n",
    "PATH_TO_DATA = './twitterdata/'\n",
    "# Directory that holds the label file\n",
    "PATH_TO_LABELS = './labels/'\n",
    "# Directory handed to the learner as its output location\n",
    "OUTPUT_DIR = './'"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Read the relevant data from Chapter 5, shuffle it, split it into training, validation and test sets (60/20/20), and save each set as a CSV file."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Load the processed data from Chapter 5 and keep only the two columns that\n",
    "# remain after the drop; they are renamed to 'text' and 'label'.\n",
    "df = pd.read_csv('../chapter5/train-processed.csv', encoding='latin-1')\n",
    "df = df.drop(df.columns[[0, 1, 2, 3, 4, 6]], axis=1)\n",
    "df.columns = ['text', 'label']\n",
    "\n",
    "# Shuffle reproducibly, then cut at 60% and 80% for a 60/20/20 split.\n",
    "# Split idea from https://stackoverflow.com/questions/38250710/\n",
    "# how-to-split-data-into-3-sets-train-validation-and-test/38251213#38251213\n",
    "# Positional slicing is used instead of np.split(DataFrame, ...), which\n",
    "# depends on the deprecated DataFrame.swapaxes.\n",
    "np.random.seed(0)\n",
    "shuffled = df.sample(frac=1)\n",
    "train_end = int(.6 * len(df))\n",
    "valid_end = int(.8 * len(df))\n",
    "train = shuffled.iloc[:train_end]\n",
    "valid = shuffled.iloc[train_end:valid_end]\n",
    "test = shuffled.iloc[valid_end:]\n",
    "assert len(train) + len(valid) + len(test) == len(df)\n",
    "\n",
    "# Write the splits into PATH_TO_DATA, creating the directory if it is missing\n",
    "# so the notebook also runs top to bottom on a fresh checkout.\n",
    "os.makedirs(PATH_TO_DATA, exist_ok=True)\n",
    "train.to_csv(os.path.join(PATH_TO_DATA, 'train.csv'), index=False)\n",
    "valid.to_csv(os.path.join(PATH_TO_DATA, 'valid.csv'), index=False)\n",
    "test.to_csv(os.path.join(PATH_TO_DATA, 'test.csv'), index=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Get the unique labels and save them as a CSV file in the separate `labels` directory (`PATH_TO_LABELS`)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# One row per distinct label value, written without header or index.\n",
    "# The file goes into PATH_TO_LABELS (created if missing) so it stays in sync\n",
    "# with the label directory that is passed to the data bunch.\n",
    "os.makedirs(PATH_TO_LABELS, exist_ok=True)\n",
    "labels = pd.DataFrame(df['label'].unique())\n",
    "labels.to_csv(os.path.join(PATH_TO_LABELS, 'labels.csv'), header=False, index=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "Define and train model"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Train on the GPU when one is available; fall back to the CPU so the cell\n",
    "# still runs on machines without CUDA.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "logger = logging.getLogger()\n",
    "metrics = [{'name': 'accuracy', 'function': accuracy}]\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',\n",
    "                                          do_lower_case=True)\n",
    "\n",
    "# fast-bert treats `test_data` as a list of raw texts, not as a file name\n",
    "# (a file name would be iterated character by character), so load the text\n",
    "# column of the saved test split.\n",
    "test_texts = pd.read_csv(os.path.join(PATH_TO_DATA, 'test.csv'))['text'] \\\n",
    "               .astype(str).tolist()\n",
    "\n",
    "# text_col / label_col refer to the header names written to the CSV files.\n",
    "databunch = BertDataBunch(PATH_TO_DATA,\n",
    "                          PATH_TO_LABELS,\n",
    "                          tokenizer,\n",
    "                          train_file='train.csv',\n",
    "                          val_file='valid.csv',\n",
    "                          test_data=test_texts,\n",
    "                          text_col='text', label_col='label',\n",
    "                          batch_size_per_gpu=32,\n",
    "                          max_seq_length=140,\n",
    "                          multi_gpu=False,\n",
    "                          multi_label=False,\n",
    "                          model_type='bert')\n",
    "\n",
    "learner = BertLearner.from_pretrained_model(databunch,\n",
    "                                            'bert-base-uncased',\n",
    "                                            metrics=metrics,\n",
    "                                            device=device,\n",
    "                                            logger=logger,\n",
    "                                            output_dir=OUTPUT_DIR,\n",
    "                                            is_fp16=False,\n",
    "                                            multi_gpu=False,\n",
    "                                            multi_label=False)\n",
    "\n",
    "# Fine-tuning a pretrained BERT needs a small learning rate (on the order of\n",
    "# 1e-5 to 1e-4); 1e-2 destroys the pretrained weights instead of adapting them.\n",
    "learner.fit(epochs=3, lr=6e-5)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}